package com.lee.algorithm.linkedlist;

import com.lee.algorithm.linkedlist.struct.OneWayNode;

/***
 * @description: 划分单项链表区间
 *      将单项链表按某值划分成左边小、中间相等、右边大的形式
 * @author : 青石路
 * @date: 2021/11/29 21:59
 */
public class PartitionListInterval {

    /**
     * 1、元素放入到数组中，在数组中进行partition，最后再将数组中的元素用单项链表串起来
     * @param head
     */
    public static void partitionInterval(OneWayNode<Integer> head, int target) {
        if (head == null) {
            return;
        }

        // 链表转数组
        OneWayNode cur = head;
        int len = 0;
        while(cur != null) {
            len ++;
            cur = cur.next;
        }
        OneWayNode<Integer>[] nodeArr = new OneWayNode[len];
        cur = head;
        for(int i=0; cur != null; i++) {
            nodeArr[i] = cur;
            cur = cur.next;
        }

        // 分区处理
        int ltp = -1;        // 小于区
        int gtp = len;      // 大于区
        int index = 0;
        while(index < gtp) {
            if (nodeArr[index].value < target) {
                swap(nodeArr, ++ltp, index++);      // 小于区的下一个元素与当前元素交换
            } else if (nodeArr[index].value > target) {
                swap(nodeArr, --gtp, index);        // 大于区的上一个元素与当前元素交换
            } else {
                index ++;
            }
        }

        // 数组转链表
        head = nodeArr[0];
        cur = head;
        for(index=1; index<len; index++) {
            cur.next = nodeArr[index];
            cur = cur.next;
        }
        cur.next = null;

        print(head);
    }

    /**
     * 2、6个变量，初始值都 = null（边界（null）要注意）
     *      sh：小于部分的head
     *      st：小于部分的tail
     *      eh：等于部分的head
     *      et：小于部分的tail
     *      bh：大于部分的head
     *      bt：大于部分的tail
     * @param head
     */
    public static void partitionIntervalPlus(OneWayNode<Integer> head, int target) {
        OneWayNode<Integer> sh = null, st = null, eh = null, et = null, bh = null, bt = null;
        OneWayNode<Integer> cur = head;

        while(cur != null) {
            if (cur.value < target) {
                if (sh == null) {
                    sh = cur;
                    st = cur;
                } else {
                    st.next = cur;
                    st = st.next;
                }
            } else if (cur.value > target) {
                if (bh == null) {
                    bh = cur;
                    bt = cur;
                } else {
                    bt.next = cur;
                    bt = bt.next;
                }
            } else {
                if (eh == null) {
                    eh = cur;
                    et = cur;
                } else {
                    et.next = cur;
                    et = et.next;
                }
            }

            cur = cur.next;
        }
        if (st != null) {
            st.next = null;
        }
        if (et != null) {
            et.next = null;
        }
        if (bt != null) {
            bt.next = null;
        }

        head = null;
        cur = null;
        while (sh != null) {
            if (head == null) {
                head = sh;
                cur = sh;
            } else {
                cur.next = sh;
                cur = cur.next;
            }
            sh = sh.next;
        }
        while (eh != null) {
            if (head == null) {
                head = eh;
                cur = sh;
            } else {
                cur.next = eh;
                cur = cur.next;
            }
            eh = eh.next;
        }
        while (bh != null) {
            if (head == null) {
                head = bh;
                cur = bh;
            } else {
                cur.next = bh;
                cur = cur.next;
            }
            bh = bh.next;
        }

        print(head);
    }

    public static void main(String[] args) {
        OneWayNode head = new OneWayNode(2);        // 此处的 head 始终指向 2 这个节点
        OneWayNode n1 = new OneWayNode(1);
        OneWayNode n2 = new OneWayNode(6);
        OneWayNode n3 = new OneWayNode(7);
        OneWayNode n4 = new OneWayNode(4);
        OneWayNode n5 = new OneWayNode(3);
        head.next = n1;
        n1.next = n2;
        n2.next = n3;
        n3.next = n4;
        n4.next = n5;
        int target = 2;

        // partitionInterval(head, target);

        partitionIntervalPlus(head, target);
    }

    public static void swap(OneWayNode<Integer>[] nodeArr, int i, int j) {
        OneWayNode t = nodeArr[i];
        nodeArr[i] = nodeArr[j];
        nodeArr[j] = t;
    }

    public static void print(OneWayNode<Integer>[] nodeArr) {
        for(OneWayNode node : nodeArr) {
            System.out.print(node.value + " ");
        }
        System.out.println();
    }

    private static void print(OneWayNode<Integer> head) {
        OneWayNode cur = head;
        while(cur != null) {
            System.out.print(cur.value + " ");
            cur = cur.next;
        }
        System.out.println();
    }
}
